import numpy as np

from utils.accuracy import Accuracy

# Single module-level Accuracy instance, reused by compute_metrics and by the
# __main__ smoke check below. Accuracy comes from the project-local
# utils.accuracy module; its compute(predictions=..., references=...) call
# shape is inferred from the usages in this file — TODO confirm against
# utils.accuracy.
metric = Accuracy()


def compute_metrics(eval_pred):
    """Score model outputs with the module-level ``metric``.

    Args:
        eval_pred: A two-item iterable unpacked as ``(logits, labels)``.
            Presumably the evaluation-prediction pair handed to a training
            loop's ``compute_metrics`` hook — confirm with the caller.

    Returns:
        Whatever ``metric.compute`` returns when given the arg-max
        predictions and ``labels`` as references.

    Raises:
        ValueError: if ``eval_pred`` does not unpack into exactly two items.
    """
    raw_scores, gold = eval_pred
    # The predicted class is the index of the largest score along the final
    # axis of the logits.
    predicted = np.argmax(raw_scores, axis=-1)
    return metric.compute(
        predictions=predicted,
        references=gold,
    )


if __name__ == '__main__':
    # Smoke check: two toy examples where every prediction matches its label.
    toy_labels = [0, 1]
    toy_predictions = [0, 1]
    print(metric.compute(references=toy_labels, predictions=toy_predictions))
